import random
import os
import pickle as pkl
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from utils import *
from layers import MLP
from transformers import BertModel


class BubbleEmbed(nn.Module):
    """Taxonomy embedding that represents every concept as a hypersphere ("bubble").

    A BERT encoder produces a first-token ([CLS]) vector for a concept's input;
    two MLP heads project it to a bubble center (``args.embed_size`` dims) and a
    strictly positive radius (``exp`` of the second head). Training pushes each
    child bubble inside its parent bubble and away from sampled negative parents.

    Fields read from ``args`` by this class: ``dataset``, ``expID``, ``cuda``,
    ``hidden``, ``embed_size``, ``dropout``, ``margin``, ``epsilon``, ``phi``,
    ``minvol``, ``alpha``, ``beta``, ``gamma``, ``extra``, ``radratio``,
    ``contrastive``.

    NOTE(review): shapes are not enforced anywhere; from the ``.unsqueeze(1)``
    applied to center distances, centers appear to be ``(batch, embed_size)``
    and radii ``(batch, 1)`` — confirm against ``layers.MLP``.
    """

    def __init__(self, args, tokenizer):
        """Load taxonomy data and the BERT encoder, and build projection heads.

        Args:
            args: Experiment configuration (fields listed in the class docstring).
            tokenizer: Accepted for interface compatibility; not used here.
        """
        super(BubbleEmbed, self).__init__()

        self.args = args
        self.data = self.__load_data__(self.args.dataset)
        # Legacy tensor constructor; not referenced elsewhere in this class.
        self.FloatTensor = (
            torch.cuda.FloatTensor if self.args.cuda else torch.FloatTensor
        )
        self.concept_set = self.data["concept_set"]
        self.concept_id = self.data["concept2id"]
        self.id_concept = self.data["id2concept"]
        self.id_context = self.data["id2context"]

        self.train_concept_set = list(self.data["train_concept_set"])
        self.train_taxo_dict = self.data["train_taxo_dict"]
        self.train_child_parent_negative_parent_triple = self.data[
            "train_child_parent_negative_parent_triple"
        ]
        self.path2root = self.data["path2root"]
        self.test_concepts_id = self.data["test_concepts_id"]
        self.test_gt_id = self.data["test_gt_id"]

        self.pre_train_model = self.__load_pre_trained__()

        # 768 is the hidden size of bert-base-uncased (loaded above).
        self.projection_center = MLP(
            input_dim=768, hidden=self.args.hidden, output_dim=self.args.embed_size
        )
        self.projection_delta = MLP(
            input_dim=768, hidden=self.args.hidden, output_dim=1
        )

        self.dropout = nn.Dropout(self.args.dropout)

        self.par_chd_loss = nn.MSELoss()
        self.par_chd_negative_loss = nn.MSELoss()
        self.bubble_size_loss = nn.MSELoss()
        self.positive_prob_loss = nn.BCELoss()
        self.negative_prob_loss = nn.BCELoss()

        self.num_dimensions = self.args.embed_size
        # Volume of the unit ball in `embed_size` dimensions; a bubble's
        # volume is this factor times radius ** embed_size.
        self.volume_factor = self._unit_ball_volume(args.embed_size)

    @staticmethod
    def _unit_ball_volume(dim):
        """Return the volume of the unit ball in ``dim`` dimensions.

        Evaluates ``pi**(dim/2) / Gamma(dim/2 + 1)`` directly. For large
        ``dim`` that overflows (``math.gamma`` raises ``OverflowError`` once its
        argument exceeds roughly 171, and ``pi ** x`` overflows later), so fall
        back to the equivalent log-space form, which underflows towards 0.0
        instead of crashing model construction.
        """
        half = dim / 2
        try:
            return (math.pi ** half) / math.gamma(half + 1)
        except OverflowError:
            return math.exp(half * math.log(math.pi) - math.lgamma(half + 1))

    def __load_data__(self, dataset):
        """Unpickle ``../data/<dataset>/processed/taxonomy_data_<expID>_.pkl``.

        The path is relative to the current working directory.
        NOTE: ``pickle.load`` can execute arbitrary code; only point this at
        trusted, locally produced files.
        """
        with open(
            os.path.join(
                "../data/",
                dataset,
                "processed",
                "taxonomy_data_" + str(self.args.expID) + "_.pkl",
            ),
            "rb",
        ) as f:
            data = pkl.load(f)

        return data

    def __load_pre_trained__(self):
        """Load the pre-trained ``bert-base-uncased`` encoder."""
        model = BertModel.from_pretrained("bert-base-uncased")
        print("Model Loaded!")
        return model

    def parent_child_contain_loss(self, parent_center, parent_delta, child_center, child_delta):
        """Containment hinge loss; computes the center distance, then defers to
        ``parent_child_contain_loss_cached``."""
        dist_center = self.center_distance(parent_center, child_center).unsqueeze(1)
        return self.parent_child_contain_loss_cached(
            parent_center, parent_delta, child_center, child_delta, dist_center
        )

    def volume_loss(self, parent_delta, child_delta):
        """Penalize children whose bubble volume is at least their parent's.

        Returns the MSE between ``max(child_vol - parent_vol, 0)`` and zero.
        """
        ones = torch.ones_like(parent_delta)
        zeros = torch.zeros_like(parent_delta)

        child_vol = self.bubble_volume(child_delta)
        parent_vol = self.bubble_volume(parent_delta)
        vol_diff = child_vol - parent_vol
        loss_mask = torch.where(vol_diff >= zeros, ones, zeros)
        diff = torch.mul(vol_diff, loss_mask)
        vol_loss = self.par_chd_loss(diff, zeros)
        return vol_loss

    def parent_child_contain_loss_prob(self, parent_center, parent_delta, child_center, child_delta):
        """BCE pushing the child-in-parent score towards 1; computes the center
        distance, then defers to ``parent_child_contain_loss_prob_cached``."""
        dist_center = self.center_distance(child_center, parent_center).unsqueeze(1)
        return self.parent_child_contain_loss_prob_cached(
            parent_center, parent_delta, child_center, child_delta, dist_center
        )

    def center_distance(self, center1, center2):
        """Euclidean distance between centers along the last dimension."""
        return torch.linalg.norm(center1 - center2, 2, -1)

    def radial_intersection(self, center1, delta1, center2, delta2):
        """Intersection "radius" of two bubbles; computes the center distance,
        then defers to ``radial_intersection_cached``."""
        dist_center = self.center_distance(center1, center2).unsqueeze(1)
        return self.radial_intersection_cached(center1, delta1, center2, delta2, dist_center)

    def negative_contain_loss(
        self, child_center, child_delta, neg_parent_center, neg_parent_delta
    ):
        """Overlap penalty against a negative parent; computes the center
        distance, then defers to ``negative_contain_loss_cached``."""
        dist_center = self.center_distance(neg_parent_center, child_center).unsqueeze(1)
        return self.negative_contain_loss_cached(
            child_center, child_delta, neg_parent_center, neg_parent_delta, dist_center
        )

    def negative_contain_loss_prob(self, child_center, child_delta, neg_parent_center, neg_parent_delta):
        """BCE pushing the child-in-negative-parent score towards 0; computes the
        center distance, then defers to ``negative_contain_loss_prob_cached``."""
        dist_center = self.center_distance(child_center, neg_parent_center).unsqueeze(1)
        return self.negative_contain_loss_prob_cached(
            child_center, child_delta, neg_parent_center, neg_parent_delta, dist_center
        )

    def projection_bubble(self, encode_inputs):
        """Encode inputs with BERT and project the first-token vector to a bubble.

        Args:
            encode_inputs: Mapping unpacked as keyword arguments into the BERT
                model (presumably tokenizer output — confirm with caller).

        Returns:
            ``(center, radius)``; ``radius = exp(...)`` clamped to >= 1e-15.
        """
        cls = self.pre_train_model(**encode_inputs)
        cls = self.dropout(cls[0][:, 0, :])
        center = self.projection_center(cls)
        radius = torch.exp(self.projection_delta(cls)).clamp_min(1e-15)
        return center, radius

    def bubble_volume(self, delta, temperature=0.1):
        """Hypersphere volume ``volume_factor * delta ** num_dimensions``.

        Non-positive radii get volume 0. ``temperature`` is unused.
        """
        valid_mask = (delta > 0).float()
        volume = self.volume_factor * (torch.pow(delta, self.num_dimensions))
        return volume * valid_mask

    def bubble_regularization(self, delta):
        """
        Regularize bubble sizes to ensure they don't get too small.
        Penalizes bubbles with radius smaller than self.args.phi.

        Args:
            delta: Tensor of bubble radii

        Returns:
            Regularization loss: MSE between the radius and ``phi`` over the
            entries below ``phi`` (other entries contribute 0 to the mean).
        """
        zeros = torch.zeros_like(delta)
        ones = torch.ones_like(delta)
        min_radius = torch.ones_like(delta) * self.args.phi

        # Only bubbles smaller than the minimum size contribute to the loss.
        small_bubble_mask = torch.where(delta < self.args.phi, ones, zeros)

        regular_loss = self.bubble_size_loss(
            torch.mul(delta, small_bubble_mask),
            torch.mul(min_radius, small_bubble_mask)
        )

        return regular_loss

    def condition_score(self, child_center, child_delta, parent_center, parent_delta, temperature=0.1):
        """Child-in-parent score and parent volume; computes the center distance,
        then defers to ``condition_score_cached``. ``temperature`` is unused."""
        dist_center = self.center_distance(child_center, parent_center).unsqueeze(1)
        return self.condition_score_cached(
            child_center, child_delta, parent_center, parent_delta, dist_center
        )

    def is_contain(self, child_center, child_delta, parent_center, parent_delta):
        """
        Checks if child bubble (hypersphere) is enclosed inside parent bubble.

        A hypersphere is contained in another if the distance between centers
        plus the radius of the inner sphere is less than or equal to the radius
        of the outer sphere.

        Args:
            child_center: Center coordinates of child sphere
            child_delta: Radius of child sphere
            parent_center: Center coordinates of parent sphere
            parent_delta: Radius of parent sphere

        Returns:
            Float tensor with 1.0 where containment is satisfied, 0.0 elsewhere
            (squeezed, so a batch of one yields a 0-d tensor).
        """
        dist_center = self.center_distance(child_center, parent_center).unsqueeze(1)

        # Containment: dist_center + child_radius <= parent_radius.
        containment = parent_delta - child_delta - dist_center
        mask = (containment >= 0).float()

        return mask.squeeze()

    def parent_child_contain_loss_cached(self, parent_center, parent_delta, child_center, child_delta, dist_center):
        """Containment hinge loss given a precomputed center distance.

        The slack ``parent_delta - (child_delta + dist_center)`` should exceed
        ``args.margin``; where it does not, the MSE between slack and margin is
        taken (other entries contribute 0 to the mean). The center arguments
        are unused and kept for signature symmetry.
        """
        ones = torch.ones_like(parent_delta)
        margins = ones * self.args.margin
        diff = parent_delta - (child_delta + dist_center)
        loss_mask = (diff <= margins).float()

        # Compute loss only where violations occur.
        loss = self.par_chd_loss(torch.mul(diff, loss_mask), torch.mul(margins, loss_mask))
        return loss

    def parent_child_contain_loss_prob_cached(self, parent_center, parent_delta, child_center, child_delta, dist_center):
        """BCE between the child-in-parent score (clamped to (0, 1)) and 1."""
        score, _ = self.condition_score_cached(child_center, child_delta, parent_center, parent_delta, dist_center)
        ones = torch.ones_like(score)
        # Keep BCE away from log(0).
        score = score.clamp(1e-7, 1 - 1e-7)
        loss = self.positive_prob_loss(score, ones)
        return loss

    def negative_contain_loss_cached(self, child_center, child_delta, neg_parent_center, neg_parent_delta, dist_center):
        """Overlap penalty against a negative parent given a precomputed distance.

        Where the overlap ``child_delta + neg_parent_delta - dist_center``
        exceeds ``args.epsilon``, the MSE between overlap and epsilon is taken
        (other entries contribute 0 to the mean).
        """
        ones = torch.ones_like(neg_parent_delta)
        zeros = torch.zeros_like(neg_parent_delta)

        sum_radius = neg_parent_delta + child_delta
        epsilon = ones * self.args.epsilon

        diff = sum_radius - dist_center
        loss_mask = torch.where(diff > epsilon, ones, zeros)

        loss = self.par_chd_negative_loss(torch.mul(diff, loss_mask), torch.mul(epsilon, loss_mask))
        return loss

    def negative_contain_loss_prob_cached(self, child_center, child_delta, neg_parent_center, neg_parent_delta, dist_center):
        """BCE between the child-in-negative-parent score (clamped) and 0."""
        score, _ = self.condition_score_cached(child_center, child_delta, neg_parent_center, neg_parent_delta, dist_center)
        zeros = torch.zeros_like(score)
        # Keep BCE away from log(0).
        score = score.clamp(1e-7, 1 - 1e-7)
        loss = self.negative_prob_loss(score, zeros)
        return loss

    def radial_intersection_cached(self, center1, delta1, center2, delta2, dist_center):
        """Half the overlap depth of two bubbles, capped at the smaller radius.

        Returns ``(delta1 + delta2 - dist_center) / 2`` where the bubbles
        overlap (0 elsewhere), limited to ``min(delta1, delta2)``.
        """
        sum_radius = delta1 + delta2
        mask = (dist_center < sum_radius).float()
        intersection_radius = mask * ((sum_radius - dist_center) / 2)
        intersection_radius = torch.min(intersection_radius, torch.min(delta1, delta2))
        return intersection_radius

    def parent_child_log_vol_similarity(self, delta1, delta2):
        """Soft log-space penalty when ``(delta2/delta1) ** num_dimensions`` is
        below ``args.minvol``; ``delta1`` is the parent and ``delta2`` the child
        radius in ``forward``'s usage."""
        min_log_ratio = math.log(self.args.minvol)
        rad_ratio = (delta2 / delta1).clamp(min=1e-10)
        log_rad_ratio = torch.log(rad_ratio)
        log_volume_ratio = log_rad_ratio * self.num_dimensions
        # Soft masking for better gradients.
        weight = torch.sigmoid(3 * (min_log_ratio - log_volume_ratio))
        penalty = ((min_log_ratio - log_volume_ratio) ** 2) * weight
        return penalty.mean()

    def parent_child_volume_similarity(self, delta1, delta2):
        """Soft penalty when the volume ratio ``(delta2/delta1) ** num_dimensions``
        falls below ``args.minvol``.

        ``delta1`` is the parent and ``delta2`` the child radius in
        ``forward``'s usage. Pairs with a non-positive radius contribute 0.
        """
        min_ratio = self.args.minvol
        valid = (delta1 > 0) & (delta2 > 0)
        valid_mask = valid.float()
        # Divide by a placeholder 1 on invalid pairs *before* masking: masking
        # the quotient afterwards turns a zero parent radius into inf * 0 = NaN,
        # which would poison the mean.
        safe_delta1 = torch.where(valid, delta1, torch.ones_like(delta1))
        rad_ratio = (delta2 / safe_delta1) * valid_mask
        volume_ratio = torch.pow(rad_ratio, self.num_dimensions)

        # Soft masking for better gradients.
        weight = torch.sigmoid(8 * (min_ratio - volume_ratio))
        penalty = ((min_ratio - volume_ratio) ** 2) * weight * valid_mask
        return penalty.mean()

    def condition_score_cached(self, child_center, child_delta, parent_center, parent_delta, dist_center):
        """Return ``(score, parent_volume)``, both squeezed.

        ``score`` is the intersection radius divided by the child radius; since
        the intersection is capped at the child radius, it lies in [0, 1] for
        positive radii.
        """
        inter_delta = self.radial_intersection_cached(
            child_center, child_delta, parent_center, parent_delta, dist_center
        )
        mask = (inter_delta > 0).float()
        masked_inter_delta = inter_delta * mask
        scores = masked_inter_delta / child_delta
        parent_volumes = self.bubble_volume(parent_delta)
        return scores.squeeze(), parent_volumes.squeeze()

    def compute_node_depths(self):
        """Compute depth of each node in the taxonomy using path2root data.

        Depth is ``len(path) - 1``, so a node whose path has length 1 has depth 0.
        """
        return {node_id: len(path) - 1 for node_id, path in self.path2root.items()}

    def center_contrastive_loss_cached(self, pc_dist, nc_dist):
        """Mean hinge ``relu(pc_dist - nc_dist)``: the positive parent's center
        should be no farther from the child than the negative parent's."""
        loss = torch.mean(F.relu((pc_dist - nc_dist)))
        return loss

    def forward(
        self,
        encode_parent=None,
        encode_child=None,
        encode_negative_parents=None,
        parent_ids=None,
        child_ids=None,
        negative_parent_ids=None,
        flag="train",
    ):
        """Compute the training loss for a batch of (parent, child, negative) inputs.

        Args:
            encode_parent, encode_child, encode_negative_parents: Inputs passed
                to ``projection_bubble`` for each role.
            parent_ids, child_ids, negative_parent_ids: Unused here.
            flag: Only ``"train"`` is implemented.

        Returns:
            ``(loss, loss_contain, loss_negative, regular_loss, loss_pos_prob,
            loss_neg_prob)``, where ``loss`` is the sum of the other five terms.

        Raises:
            ValueError: If ``flag`` is not ``"train"``.
        """
        if flag != "train":
            # Without this guard the outputs below are never bound and the
            # call dies with an UnboundLocalError at the return statement.
            raise ValueError(
                f"BubbleEmbed.forward only supports flag='train', got {flag!r}"
            )

        parent_center, parent_delta = self.projection_bubble(encode_parent)
        child_center, child_delta = self.projection_bubble(encode_child)
        neg_parent_center, neg_parent_delta = self.projection_bubble(
            encode_negative_parents
        )

        # Center distances, computed once and shared by the *_cached losses.
        pc_dist = self.center_distance(parent_center, child_center).unsqueeze(1)
        nc_dist = self.center_distance(neg_parent_center, child_center).unsqueeze(1)

        parent_child_vol_loss = self.volume_loss(parent_delta, child_delta)

        parent_child_contain_loss = self.parent_child_contain_loss_cached(
            parent_center, parent_delta, child_center, child_delta, pc_dist
        )

        child_parent_negative_loss = self.negative_contain_loss_cached(
            child_center, child_delta, neg_parent_center, neg_parent_delta, nc_dist
        )

        loss_contain = self.args.alpha * (parent_child_contain_loss + parent_child_vol_loss)
        loss_negative = self.args.beta * child_parent_negative_loss

        parent_child_contain_loss_prob = self.parent_child_contain_loss_prob_cached(
            parent_center, parent_delta, child_center, child_delta, pc_dist
        )

        child_parent_negative_loss_prob = self.negative_contain_loss_prob_cached(
            child_center, child_delta, neg_parent_center, neg_parent_delta, nc_dist
        )

        regular_loss = 0
        regular_loss += self.bubble_regularization(parent_delta)
        regular_loss += self.bubble_regularization(child_delta)
        regular_loss += self.bubble_regularization(neg_parent_delta)
        if self.args.radratio:
            # parent_child_log_vol_similarity is a log-space alternative.
            regular_loss += self.parent_child_volume_similarity(parent_delta, child_delta)
        if self.args.contrastive:
            regular_loss += self.center_contrastive_loss_cached(pc_dist, nc_dist)

        regular_loss = self.args.gamma * regular_loss
        loss_pos_prob = self.args.extra * parent_child_contain_loss_prob
        loss_neg_prob = self.args.extra * child_parent_negative_loss_prob

        loss = loss_contain + loss_negative + regular_loss + loss_pos_prob + loss_neg_prob

        return (
            loss,
            loss_contain,
            loss_negative,
            regular_loss,
            loss_pos_prob,
            loss_neg_prob
        )